{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CrypTen - Training an Encrypted Neural Network across Workers using Jails\n",
    "\n",
    "In this tutorial we will train an encrypted neural network across different PySyft workers using Jails, an experimental mechanism into PySyft, we will be using CrypTen as a backend for SMPC. We use The MNIST dataset for this tutorial, the features will be split across alice and bob workers, each one will contain a set of pixels for every entry.\n",
    "\n",
    "Authors:\n",
    "- Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "You should first know that you need to install PySyft from the `crypten` branch to be able to run the tutorial.\n",
    "\n",
    "To prepare the dataset, you should run `./mnist_utils.py --option features --reduced 100 --binary` using the [mnist_utils.py from CrypTen](https://github.com/facebookresearch/CrypTen/blob/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py).\n",
    "\n",
    "You should also start two GridNode with IDs alice and bob listening to ports 3000 and 3001 respectively, you should update the code if you change those information. For me I cloned the [PyGrid repo](https://github.com/OpenMined/PyGrid) and started the two GridNodes by running `./run.sh --id ID --port PORT --start_local_db` in separate terminals (you need to go in *apps/node*)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import crypten\n",
    "import syft\n",
    "from time import time\n",
    "from syft.frameworks.crypten.context import run_multiworkers\n",
    "from syft.grid.clients.data_centric_fl_client import DataCentricFLClient\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "torch.set_num_threads(1)\n",
    "hook = syft.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the neural network that we wanna use for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define an example network\n",
    "class ExampleNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ExampleNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)\n",
    "        self.fc1 = nn.Linear(16 * 12 * 12, 100)\n",
    "        self.fc2 = nn.Linear(100, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = out.view(-1, 16 * 12 * 12)\n",
    "        out = self.fc1(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out"
   ]
  },
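  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `16 * 12 * 12` input size of `fc1` follows from the layer shapes. As a quick sanity check, here is a minimal sketch in plain Python, assuming 28x28 MNIST inputs (the same size as `dummy_input` later in this notebook). `conv_out` is a hypothetical helper written for this illustration, not part of PyTorch or CrypTen.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel, padding=0, stride=1):\n",
    "    # Output side length of a square convolution or pooling window\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "side = conv_out(28, kernel=5)              # conv1: 28 -> 24\n",
    "side = conv_out(side, kernel=2, stride=2)  # max_pool2d(out, 2): 24 -> 12\n",
    "flat_features = 16 * side * side           # 16 output channels from conv1\n",
    "print(side, flat_features)                 # prints: 12 2304\n",
    "```\n",
    "\n",
    "If you change the kernel size, padding, or pooling, the sizes used in `fc1` and in `out.view(...)` have to change accordingly."
   ]
  },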
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing remote workers\n",
    "\n",
    "We now connect to alice and bob via their respective ports (update the url if you are running workers in a remote machine or used a different port), then prepare and send the data to the different workers. In a real scenario, the data would be already there stored privately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Connecting to workers ...\n",
      "[+] Connected to workers\n",
      "[%] Sending labels and training data ...\n",
      "[+] Data ready\n"
     ]
    }
   ],
   "source": [
    "# Syft workers\n",
    "print(\"[%] Connecting to workers ...\")\n",
    "ALICE = DataCentricFLClient(hook, \"ws://localhost:3000\")\n",
    "BOB = DataCentricFLClient(hook, \"ws://localhost:3001\")\n",
    "print(\"[+] Connected to workers\")\n",
    "\n",
    "print(\"[%] Sending labels and training data ...\")\n",
    "# Prepare and send labels\n",
    "label_eye = torch.eye(2)\n",
    "labels = torch.load(\"/tmp/train_labels.pth\")\n",
    "labels = labels.long()\n",
    "labels_one_hot = label_eye[labels]\n",
    "labels_one_hot.tag(\"labels\")\n",
    "al_ptr = labels_one_hot.send(ALICE)\n",
    "bl_ptr = labels_one_hot.send(BOB)\n",
    "\n",
    "# Prepare and send training data\n",
    "alice_train = torch.load(\"/tmp/alice_train.pth\").tag(\"alice_train\")\n",
    "at_ptr = alice_train.send(ALICE)\n",
    "bob_train = torch.load(\"/tmp/bob_train.pth\").tag(\"bob_train\")\n",
    "bt_ptr = bob_train.send(BOB)\n",
    "\n",
    "print(\"[+] Data ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now instanciate a model and create a dummy input that could be forwarded through it, this is needed to build the CrypTen model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "dummy_input = torch.empty(1, 1, 28, 28)\n",
    "pytorch_model = ExampleNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define and run the CrypTen computation using PySyft\n",
    "\n",
    "Here we define the CrypTen computation for training the neural network, you only need to decorate it with `@run_multiworkers` to run the training across different workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@run_multiworkers([ALICE, BOB], master_addr=\"127.0.0.1\", model=pytorch_model, dummy_input=dummy_input)\n",
    "def run_encrypted_training():\n",
    "    rank = crypten.communicator.get().get_rank()\n",
    "    \n",
    "    # Load the labels\n",
    "    worker = syft.frameworks.crypten.get_worker_from_rank(rank)\n",
    "    labels_one_hot = worker.search(\"labels\")[0]\n",
    "\n",
    "    # Load data:\n",
    "    x_alice_enc = crypten.load(\"alice_train\", 0)\n",
    "    x_bob_enc = crypten.load(\"bob_train\", 1)\n",
    "\n",
    "    # Combine the feature sets: identical to Tutorial 3\n",
    "    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)\n",
    "\n",
    "    # Reshape to match the network architecture\n",
    "    x_combined_enc = x_combined_enc.unsqueeze(1)\n",
    "\n",
    "    # model is sent from the master worker\n",
    "    model.encrypt()\n",
    "    # Set train mode\n",
    "    model.train()\n",
    "\n",
    "    # Define a loss function\n",
    "    loss = crypten.nn.MSELoss()\n",
    "\n",
    "    # Define training parameters\n",
    "    learning_rate = 0.001\n",
    "    num_epochs = 2\n",
    "    batch_size = 10\n",
    "    num_batches = x_combined_enc.size(0) // batch_size\n",
    "\n",
    "    for i in range(num_epochs):\n",
    "        # Print once for readability\n",
    "        if rank == 0:\n",
    "            print(f\"Epoch {i} in progress:\")\n",
    "            pass\n",
    "\n",
    "        for batch in range(num_batches):\n",
    "            # define the start and end of the training mini-batch\n",
    "            start, end = batch * batch_size, (batch + 1) * batch_size\n",
    "\n",
    "            # construct AutogradCrypTensors out of training examples / labels\n",
    "            x_train = x_combined_enc[start:end]\n",
    "            y_batch = labels_one_hot[start:end]\n",
    "            y_train = crypten.cryptensor(y_batch, requires_grad=True)\n",
    "\n",
    "            # perform forward pass:\n",
    "            output = model(x_train)\n",
    "\n",
    "            loss_value = loss(output, y_train)\n",
    "\n",
    "            # set gradients to \"zero\"\n",
    "            model.zero_grad()\n",
    "\n",
    "            # perform backward pass:\n",
    "            loss_value.backward()\n",
    "\n",
    "            # update parameters\n",
    "            model.update_parameters(learning_rate)\n",
    "\n",
    "            # Print progress every batch:\n",
    "            batch_loss = loss_value.get_plain_text()\n",
    "            if rank == 0:\n",
    "                print(f\"\\tBatch {(batch + 1)} of {num_batches} Loss {batch_loss.item():.4f}\")\n",
    "\n",
    "    model.decrypt()\n",
    "    # printed contain all the printed strings during training\n",
    "    return printed, model"
   ]
  },
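  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mini-batch arithmetic used in the training loop can be checked in isolation. The following is a minimal sketch in plain Python, assuming 100 training samples (the `--reduced 100` option from the setup). `batch_bounds` is a hypothetical helper written for this illustration.\n",
    "\n",
    "```python\n",
    "def batch_bounds(num_samples, batch_size):\n",
    "    # Same arithmetic as the training loop: a trailing partial batch is dropped\n",
    "    num_batches = num_samples // batch_size\n",
    "    return [(b * batch_size, (b + 1) * batch_size) for b in range(num_batches)]\n",
    "\n",
    "\n",
    "bounds = batch_bounds(100, 10)\n",
    "print(len(bounds), bounds[0], bounds[-1])  # prints: 10 (0, 10) (90, 100)\n",
    "```\n",
    "\n",
    "Because of the integer division, any samples left over after the last full batch are never used for training."
   ]
  },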
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we can start the distributed computation, `result` is a dictionary containing the result from every worker indexed by the rank of the party that they were running, so `result[0]` contains the result of the party 0 that was running in alice, `result[0][i]` contains the i'th returned value depending on how many values were returned."
   ]
  },
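  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the indexing concrete, here is a minimal sketch that uses a hand-written stand-in for `result`. The strings are placeholders, and representing each party's return values as a tuple is an assumption made for illustration only.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in: rank -> values returned by that party\n",
    "mock_result = {\n",
    "    0: ('log from party 0', 'model from party 0'),\n",
    "    1: ('log from party 1', 'model from party 1'),\n",
    "}\n",
    "\n",
    "# mock_result[rank][i] is the i-th value returned by the party with that rank\n",
    "mock_printed = mock_result[0][0]\n",
    "mock_model = mock_result[0][1]\n",
    "print(mock_printed)\n",
    "```"
   ]
  },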
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Starting computation\n",
      "[+] run_encrypted_training() took 45s\n",
      "Epoch 0 in progress:\n",
      "\tBatch 1 of 10 Loss 0.4638\n",
      "\tBatch 2 of 10 Loss 0.4665\n",
      "\tBatch 3 of 10 Loss 0.4063\n",
      "\tBatch 4 of 10 Loss 0.3486\n",
      "\tBatch 5 of 10 Loss 0.3314\n",
      "\tBatch 6 of 10 Loss 0.2796\n",
      "\tBatch 7 of 10 Loss 0.2768\n",
      "\tBatch 8 of 10 Loss 0.2433\n",
      "\tBatch 9 of 10 Loss 0.2458\n",
      "\tBatch 10 of 10 Loss 0.2002\n",
      "Epoch 1 in progress:\n",
      "\tBatch 1 of 10 Loss 0.1624\n",
      "\tBatch 2 of 10 Loss 0.1517\n",
      "\tBatch 3 of 10 Loss 0.1550\n",
      "\tBatch 4 of 10 Loss 0.1923\n",
      "\tBatch 5 of 10 Loss 0.1321\n",
      "\tBatch 6 of 10 Loss 0.1635\n",
      "\tBatch 7 of 10 Loss 0.2244\n",
      "\tBatch 8 of 10 Loss 0.1454\n",
      "\tBatch 9 of 10 Loss 0.1718\n",
      "\tBatch 10 of 10 Loss 0.1335\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"[%] Starting computation\")\n",
    "func_ts = time()\n",
    "result = run_encrypted_training()\n",
    "func_te = time()\n",
    "print(f\"[+] run_encrypted_training() took {int(func_te - func_ts)}s\")\n",
    "printed = result[0][0]\n",
    "model = result[0][1]\n",
    "print(printed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model returned is a CrypTen model, but we can always run the usual PySyft methods to share the parameters and so on, as far as the model in not encrypted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph unencrypted module\n",
      "(Wrapper)>FixedPrecisionTensor>[AdditiveSharingTensor]\n",
      "\t-> [PointerTensor | me:9377222447 -> alice:29717039925]\n",
      "\t-> [PointerTensor | me:41121311542 -> bob:57548505874]\n",
      "\t*crypto provider: cp*\n"
     ]
    }
   ],
   "source": [
    "cp = syft.VirtualWorker(hook=hook, id=\"cp\")\n",
    "model.fix_prec()\n",
    "model.share(ALICE, BOB, crypto_provider=cp)\n",
    "print(model)\n",
    "print(list(model.parameters())[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "he best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\"\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
